import bluepysnap
from scipy.special import expit
import scipy as sc
import scipy.stats  # load the submodule explicitly so sc.stats.hmean also resolves on older SciPy
from scipy.optimize import curve_fit
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt


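def get_spindle_amp(sim):
    """Return per-window spindle amplitudes (Hz), durations (ms) and frequencies (Hz) for Rt_RC cells."""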
    report = sim.spikes["thalamus_neurons"].spike_report.filter(group={"mtype": "Rt_RC"}).report

    times = report.index
    time_start = np.min(times)
    time_stop = np.max(times)
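    time_binsize = 5  # histogram bin width in ms
    # Regular bin edges, with time_stop appended so the last spikes are included.
    bins = np.append(np.arange(time_start, time_stop, time_binsize), time_stop)
    hist, bin_edges = np.histogram(times, bins=bins)
    # Number of distinct cells that fired at least once (silent cells are not counted).
    node_count = report[["ids", "population"]].drop_duplicates().shape[0]
    # Population firing rate in Hz: spikes per cell per second (0.001 converts ms to s).
    freq = 1.0 * hist / node_count / (0.001 * time_binsize)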

    from scipy.signal import find_peaks

    amps = []
    durations = []
    freqs = []
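    # Analyse 16 consecutive 1000 ms windows, from 3500 ms to 19500 ms.
    for step in range(16):
        low, high = 3500 + step * 1000, 4500 + step * 1000
        mask = (bin_edges[:-1] >= low) & (bin_edges[:-1] <= high)
        f = freq[mask]
        t = bins[:-1][mask]
        # Keep peaks of at least 2 Hz that are at least 3 bins wide.
        peaks = find_peaks(f, height=2, width=3)[0]
        if peaks.size <= 2:
            # Fewer than three peaks: treat the window as containing no spindle.
            amps.append(0)
            durations.append(0)
            freqs.append(0)
        else:
            amps.append(np.mean(f[peaks]))  # mean peak firing rate (Hz)
            durations.append(t[peaks[-1]] - t[peaks[0]])  # first to last peak (ms)
            # Peak-to-peak cycles per second (Hz); the factor 1000 converts 1/ms to 1/s.
            freqs.append((len(peaks) - 1) / (t[peaks[-1]] - t[peaks[0]]) * 1000)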
    return {"amps": amps, "durations": durations, "freqs": freqs}


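def fit_func(_x, shift_exp, slope_exp=4, amp_exp=60):
    """Sigmoid with midpoint shift_exp, slope scale slope_exp, height amp_exp and a fixed baseline of 2."""
    return expit((_x - shift_exp) / slope_exp) * amp_exp + 2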


if __name__ == "__main__":

    colors = ["C0", "C1", "C2", "C3", "C4", "C5", "C3", "C7", "C8", "C9", "C10"]
    base_path = Path(
        "../../simulations/08_08_Cav_variants/"
    )
    shifts = []
    shift_errs = []
    fig_all, ax_all = plt.subplots(figsize=(5, 3))
    fig_all_d, ax_all_d = plt.subplots(figsize=(5, 2))
    fig_all_f, ax_all_f = plt.subplots(figsize=(5, 2))
    models = [
        "Ecel",
        "Spp",
        "Runaway",
        "intermedEcel",
        "intermedSpp",
        "intermedRunaway",
        "mixed50pEcel50pSppRunaway",
        "mixed75pEcel25pSppRunaway",
        "mixed25pEcel75pSppRunaway",
    ]
    for i, model in enumerate(models):
        var = i + 1  # variant numbers in the directory names are 1-based
        print(model)
        amps_mean = []
        amps_std = []
        durations_mean = []
        durations_std = []
        freqs_mean = []
        freqs_std = []
        percs = [40, 50, 60, 70, 75, 80, 85, 90, 95, 100]
        _percs = []
        for perc in percs:
            sim_path = (
                base_path
                / f"variant{var}_{model}/variant{var}_{model}_SUPER_LONG_{perc}percent_CT_up_down/simulation_config.json"
            )
            if not sim_path.exists():
                continue
            _percs.append(perc)
            sim = bluepysnap.Simulation(sim_path)
            data = get_spindle_amp(sim)
            amps_mean.append(np.mean(data["amps"]))
            amps_std.append(np.std(data["amps"]))
            durations_mean.append(np.mean(data["durations"]))
            durations_std.append(np.std(data["durations"]))
            freqs_mean.append(np.mean(data["freqs"]))
            freqs_std.append(np.std(data["freqs"]))

        amps_mean = np.array(amps_mean)
        amps_std = np.array(amps_std)
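        # Fit only shift_exp: a one-element p0 makes curve_fit treat fit_func as a
        # one-parameter model, so slope_exp and amp_exp stay at their defaults.
        # The starting value 85 is the midpoint of the bounds.
        popt, pcov = curve_fit(fit_func, _percs, amps_mean, p0=[85], bounds=([50], [120]))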
        shift_errs.append(np.sqrt(np.diag(pcov))[0])
        print(popt)
        shifts.append(popt[0])
        if i in [0, 1, 2, 6]:
            ax_all_d.errorbar(_percs, durations_mean, durations_std, c=colors[i], label=model)
            ax_all_f.errorbar(_percs, freqs_mean, freqs_std, c=colors[i], label=model)
            print(amps_std)
            ax_all.errorbar(_percs, amps_mean, amps_std, c=colors[i], ls="")
            ax_all.scatter(_percs, amps_mean, label=model, c=colors[i], s=15)
            x = np.linspace(10, 105, 1000)
            ax_all.plot(x, fit_func(x, *popt), color=colors[i], lw=0.8)

    # Finalize and save the figures once, after all models have been processed.
    ax_all.set_ylabel("Mean spindle amplitude (Hz)")
    ax_all.set_xlabel("percent CT activated")
    ax_all.set_xlim(39, 105)
    ax_all.set_ylim(0, 70)
    ax_all.legend()
    fig_all.tight_layout()
    fig_all.savefig("scan.pdf")

    ax_all_d.set_ylabel("Spindle duration [ms]")
    ax_all_d.set_xlabel("percent CT activated")
    ax_all_d.set_xlim(39, 105)
    ax_all_d.set_ylim(0, 600)
    ax_all_d.legend()
    fig_all_d.tight_layout()
    fig_all_d.savefig("scan_duration.pdf")

    ax_all_f.set_ylabel("Spindle frequency [Hz]")
    ax_all_f.set_xlabel("percent CT activated")
    ax_all_f.set_xlim(39, 105)
    ax_all_f.set_ylim(0, 15)
    ax_all_f.legend()
    fig_all_f.tight_layout()
    fig_all_f.savefig("scan_freq.pdf")

    its = [
        0.0005,
        0.0007,
        0.001,
        0.0006,
        0.0008,
        0.0009,
        sc.stats.hmean([0.0005, 0.0007, 0.001], weights=[0.5, 0.25, 0.25]),
        sc.stats.hmean([0.0005, 0.0007, 0.001], weights=[0.75, 0.125, 0.125]),
        sc.stats.hmean([0.0005, 0.0007, 0.001], weights=[0.25, 0.375, 0.375]),
    ]
    print(its)
    print(len(its), len(shifts))
    plt.figure(figsize=(5, 2))
    # names = ["ecel", "spp", "runaway", "inter1", "inter2", "inter3", "mixed1", "mixed2", "mixed3"]
    # for it, shift, name in zip(its[:3], shifts[:3], names[:3]):
    # plt.scatter(it, shift, label=name)
    # for it, shift, name in zip(its[-3:], shifts[-3:], names[-3:]):
    #    plt.scatter(it, shift, label=name)
    # plt.errorbar(its, shifts, shift_errs, ls="", c="k")
    # plt.scatter(its[:-3], shifts[:-3], c="k", s=10)
    x = np.linspace(0.00, 0.0012, 10)
    z = np.polyfit(its[:-3], shifts[:-3], 1, w=1.0 / np.array(shift_errs[:-3]))
    print(z)
    p = np.poly1d(z)
    plt.plot(x, p(x), c="k")
    plt.scatter(its[:-3], shifts[:-3], c="k", label="uniform circuit")
    plt.scatter(its[-3:], shifts[-3:], c="r", label="mixed circuit")
    plt.xlabel("it2 conductances")
    plt.ylabel("CT turning point")
    plt.axis([0, 0.0012, 60, 100])
    plt.legend()
    plt.tight_layout()
    plt.savefig("shifts_CT.pdf")
